{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "9bfdaff2-6478-4fc7-a717-c56cb5b8b1c2",
   "metadata": {},
   "source": [
    "# Defining a Custom Property Graph Retriever\n",
    "\n",
    "<a href=\"https://colab.research.google.com/github/run-llama/llama_index/blob/main/docs/examples/property_graph/property_graph_custom_retriever.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>\n",
    "\n",
    "\n",
    "This guide shows you how to define a custom retriever against a property graph.\n",
    "\n",
    "It is more involved than using our out-of-the-box graph retrievers, but it gives you granular control over the retrieval process, so you can tailor it to your application.\n",
    "\n",
    "We show you how to define an advanced retrieval flow by working directly with the property graph store. We'll run both vector search and text-to-Cypher retrieval, then combine the results through a Cohere reranking module."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c0bcf776-7e25-4f3b-9cf7-edd954cede01",
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install llama-index\n",
    "%pip install llama-index-graph-stores-neo4j\n",
    "%pip install llama-index-postprocessor-cohere-rerank"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6159b260-0061-4323-897f-f12e261da235",
   "metadata": {},
   "source": [
    "## Setup and Build the Property Graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6389fe2a-6573-4847-a7ca-756b3f94d34f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import nest_asyncio\n",
    "\n",
    "nest_asyncio.apply()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa4d0d28-d20a-4378-8f3b-2f0ed776d2ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "os.environ[\"OPENAI_API_KEY\"] = \"sk-...\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "12dc1433-c1e6-4a77-9c3b-2b4e52bae23d",
   "metadata": {},
   "source": [
    "#### Load Paul Graham Essay"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "57333740-ec57-42d6-ae60-8fc3fa58c504",
   "metadata": {},
   "outputs": [],
   "source": [
    "!mkdir -p 'data/paul_graham/'\n",
    "!wget 'https://raw.githubusercontent.com/run-llama/llama_index/main/docs/examples/data/paul_graham/paul_graham_essay.txt' -O 'data/paul_graham/paul_graham_essay.txt'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1685f988-9d8f-4ac7-bf90-822430c01414",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.core import SimpleDirectoryReader\n",
    "\n",
    "documents = SimpleDirectoryReader(\"./data/paul_graham/\").load_data()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2183f0ce-ae84-41cb-bcfa-3f81705e24e8",
   "metadata": {},
   "source": [
    "#### Define Default LLM and Embedding Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d06478af-01bc-43ed-aabc-30d193d6482a",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.embeddings.openai import OpenAIEmbedding\n",
    "from llama_index.llms.openai import OpenAI\n",
    "\n",
    "llm = OpenAI(model=\"gpt-3.5-turbo\", temperature=0.3)\n",
    "embed_model = OpenAIEmbedding(model_name=\"text-embedding-3-small\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e21fe301-b239-4e47-b4e6-cfd9560c394f",
   "metadata": {},
   "source": [
    "#### Setup Neo4j\n",
    "\n",
    "To launch Neo4j locally, first ensure you have Docker installed. Then launch the database with the following Docker command:\n",
    "\n",
    "```\n",
    "docker run \\\n",
    "    -p 7474:7474 -p 7687:7687 \\\n",
    "    -v $PWD/data:/data -v $PWD/plugins:/plugins \\\n",
    "    --name neo4j-apoc \\\n",
    "    -e NEO4J_apoc_export_file_enabled=true \\\n",
    "    -e NEO4J_apoc_import_file_enabled=true \\\n",
    "    -e NEO4J_apoc_import_file_use__neo4j__config=true \\\n",
    "    -e NEO4JLABS_PLUGINS=\\[\\\"apoc\\\"\\] \\\n",
    "    neo4j:latest\n",
    "```\n",
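    "From here, you can open the database at http://localhost:7474/, where you will be asked to sign in. Use the default username and password (`neo4j`/`neo4j`); Neo4j will then prompt you to set a new password. The code below assumes you set it to `llamaindex`."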
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cfe4d640-afdc-434c-af06-c66b6fb27bea",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.graph_stores.neo4j import Neo4jPropertyGraphStore\n",
    "\n",
    "graph_store = Neo4jPropertyGraphStore(\n",
    "    username=\"neo4j\",\n",
    "    password=\"llamaindex\",\n",
    "    url=\"bolt://localhost:7687\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "35a33a72-323a-4462-9d91-a2ae6ff04f9f",
   "metadata": {},
   "source": [
    "#### Build the Property Graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67f2faf0-e079-4f1f-97d7-65a0cd64eabb",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.core import PropertyGraphIndex\n",
    "\n",
    "index = PropertyGraphIndex.from_documents(\n",
    "    documents,\n",
    "    llm=llm,\n",
    "    embed_model=embed_model,\n",
    "    property_graph_store=graph_store,\n",
    "    show_progress=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6097923a-4cd9-4c82-9a88-4f89144e06c9",
   "metadata": {},
   "source": [
    "## Define Custom Retriever\n",
    "\n",
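    "Now we define a custom retriever by subclassing `CustomPGRetriever`.\n",
    "\n",
    "#### 1. Initialization\n",
    "We override the `init` method (not `__init__`): `CustomPGRetriever` calls it with any extra keyword arguments passed to the constructor, after setting attributes such as `self.graph_store` and `self.include_text`. Here we initialize two pre-existing property graph retrievers, `VectorContextRetriever` and `TextToCypherRetriever`, as well as the Cohere reranker.\n",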
    "\n",
    "#### 2. Define `custom_retrieve`\n",
    "\n",
    "We then define the `custom_retrieve` method. It sends the query to both retrievers, merges the returned nodes, and reranks them with Cohere to produce a final ranked list, which it joins into a single context string.\n",
    "\n",
    "The return type here can be a string, `TextNode`, `NodeWithScore`, or a list of one of those types."
   ]
  },
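  {
   "cell_type": "markdown",
   "id": "dedupe-sketch-md",
   "metadata": {},
   "source": [
    "In the `custom_retrieve` method below, the two retrievers' results are concatenated before reranking. If the vector and text-to-Cypher retrievers return overlapping content, duplicates can take up reranker slots. The next cell is a minimal, library-free sketch of one way to drop duplicates by text before reranking. `FakeNode` and `dedupe_by_text` are hypothetical stand-ins, not LlamaIndex APIs. With real `NodeWithScore` objects you would key on something like `n.node.node_id` or `n.get_content()` instead, and pass the deduplicated list to the reranker in place of `nodes_1 + nodes_2`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dedupe-sketch-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "from dataclasses import dataclass\n",
    "from typing import Dict, List\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class FakeNode:\n",
    "    \"\"\"Hypothetical stand-in for a retrieved node: text plus a relevance score.\"\"\"\n",
    "\n",
    "    text: str\n",
    "    score: float\n",
    "\n",
    "\n",
    "def dedupe_by_text(nodes: List[FakeNode]) -> List[FakeNode]:\n",
    "    \"\"\"Hypothetical helper: keep one node per distinct text, preferring the higher score.\"\"\"\n",
    "    best: Dict[str, FakeNode] = {}\n",
    "    for node in nodes:\n",
    "        current = best.get(node.text)\n",
    "        if current is None or node.score > current.score:\n",
    "            best[node.text] = node\n",
    "    # Dicts preserve insertion order, so nodes stay in first-seen order\n",
    "    return list(best.values())\n",
    "\n",
    "\n",
    "# Toy inputs standing in for the two retrievers' results\n",
    "nodes_1 = [FakeNode(\"chunk A\", 0.82), FakeNode(\"chunk B\", 0.41)]\n",
    "nodes_2 = [FakeNode(\"chunk A\", 0.77)]\n",
    "\n",
    "merged = dedupe_by_text(nodes_1 + nodes_2)\n",
    "print([(n.text, n.score) for n in merged])"
   ]
  },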
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f69a555b-4029-4b74-9c96-6da617f651aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.core.retrievers import (\n",
    "    CustomPGRetriever,\n",
    "    VectorContextRetriever,\n",
    "    TextToCypherRetriever,\n",
    ")\n",
    "from llama_index.core.graph_stores import PropertyGraphStore\n",
    "from llama_index.core.vector_stores.types import VectorStore\n",
    "from llama_index.core.embeddings import BaseEmbedding\n",
    "from llama_index.core.prompts import PromptTemplate\n",
    "from llama_index.core.llms import LLM\n",
    "from llama_index.postprocessor.cohere_rerank import CohereRerank\n",
    "\n",
    "\n",
    "from typing import Optional, Any, Union\n",
    "\n",
    "\n",
    "class MyCustomRetriever(CustomPGRetriever):\n",
    "    \"\"\"Custom retriever with cohere reranking.\"\"\"\n",
    "\n",
    "    def init(\n",
    "        self,\n",
    "        ## vector context retriever params\n",
    "        embed_model: Optional[BaseEmbedding] = None,\n",
    "        vector_store: Optional[VectorStore] = None,\n",
    "        similarity_top_k: int = 4,\n",
    "        path_depth: int = 1,\n",
    "        ## text-to-cypher params\n",
    "        llm: Optional[LLM] = None,\n",
    "        text_to_cypher_template: Optional[Union[PromptTemplate, str]] = None,\n",
    "        ## cohere reranker params\n",
    "        cohere_api_key: Optional[str] = None,\n",
    "        cohere_top_n: int = 2,\n",
    "        **kwargs: Any,\n",
    "    ) -> None:\n",
    "        \"\"\"Called by `CustomPGRetriever.__init__` with any extra kwargs passed to the constructor.\"\"\"\n",
    "\n",
    "        self.vector_retriever = VectorContextRetriever(\n",
    "            self.graph_store,\n",
    "            include_text=self.include_text,\n",
    "            embed_model=embed_model,\n",
    "            vector_store=vector_store,\n",
    "            similarity_top_k=similarity_top_k,\n",
    "            path_depth=path_depth,\n",
    "        )\n",
    "\n",
    "        self.cypher_retriever = TextToCypherRetriever(\n",
    "            self.graph_store,\n",
    "            llm=llm,\n",
    "            text_to_cypher_template=text_to_cypher_template,\n",
    "            ## NOTE: you can attach other parameters here if you'd like\n",
    "        )\n",
    "\n",
    "        self.reranker = CohereRerank(\n",
    "            api_key=cohere_api_key, top_n=cohere_top_n\n",
    "        )\n",
    "\n",
    "    def custom_retrieve(self, query_str: str) -> str:\n",
    "        \"\"\"Define custom retriever with reranking.\n",
    "\n",
    "        Could return `str`, `TextNode`, `NodeWithScore`, or a list of those.\n",
    "        \"\"\"\n",
    "        nodes_1 = self.vector_retriever.retrieve(query_str)\n",
    "        nodes_2 = self.cypher_retriever.retrieve(query_str)\n",
    "        reranked_nodes = self.reranker.postprocess_nodes(\n",
    "            nodes_1 + nodes_2, query_str=query_str\n",
    "        )\n",
    "\n",
    "        # Join the reranked nodes' content into a single context string\n",
    "        final_text = \"\\n\\n\".join(\n",
    "            [n.get_content(metadata_mode=\"llm\") for n in reranked_nodes]\n",
    "        )\n",
    "\n",
    "        return final_text\n",
    "\n",
    "    # optional async method\n",
    "    # async def acustom_retrieve(self, query_str: str) -> str:\n",
    "    #     ..."
   ]
  },
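  {
   "cell_type": "markdown",
   "id": "async-sketch-md",
   "metadata": {},
   "source": [
    "The class above leaves `acustom_retrieve` as an optional, commented-out hook. The next cell is a minimal sketch of what an async version could look like, assuming the two retrievals are independent and can run concurrently with `asyncio.gather`. `FakeNode`, `FakeRetriever`, and `fake_rerank` are hypothetical stand-ins so the control flow runs without Neo4j, OpenAI, or Cohere; `fake_rerank` simply sorts by an existing score rather than calling a reranking model. In the real class you would await `self.vector_retriever.aretrieve(query_str)` and `self.cypher_retriever.aretrieve(query_str)`, then rerank with Cohere as in `custom_retrieve`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "async-sketch-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import asyncio\n",
    "from dataclasses import dataclass\n",
    "from typing import List\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class FakeNode:\n",
    "    \"\"\"Hypothetical stand-in for a retrieved node: text plus a relevance score.\"\"\"\n",
    "\n",
    "    text: str\n",
    "    score: float\n",
    "\n",
    "\n",
    "class FakeRetriever:\n",
    "    \"\"\"Hypothetical async retriever that returns canned nodes after a short delay.\"\"\"\n",
    "\n",
    "    def __init__(self, nodes: List[FakeNode], delay: float = 0.01) -> None:\n",
    "        self.nodes = nodes\n",
    "        self.delay = delay\n",
    "\n",
    "    async def aretrieve(self, query_str: str) -> List[FakeNode]:\n",
    "        await asyncio.sleep(self.delay)  # simulate database / network latency\n",
    "        return list(self.nodes)\n",
    "\n",
    "\n",
    "def fake_rerank(nodes: List[FakeNode], top_n: int) -> List[FakeNode]:\n",
    "    \"\"\"Hypothetical reranker: keep the top_n nodes by their existing score.\"\"\"\n",
    "    return sorted(nodes, key=lambda n: n.score, reverse=True)[:top_n]\n",
    "\n",
    "\n",
    "async def acustom_retrieve(\n",
    "    query_str: str,\n",
    "    vector_retriever: FakeRetriever,\n",
    "    cypher_retriever: FakeRetriever,\n",
    "    top_n: int = 2,\n",
    ") -> str:\n",
    "    # Run both retrievers concurrently instead of one after the other\n",
    "    nodes_1, nodes_2 = await asyncio.gather(\n",
    "        vector_retriever.aretrieve(query_str),\n",
    "        cypher_retriever.aretrieve(query_str),\n",
    "    )\n",
    "    reranked = fake_rerank(nodes_1 + nodes_2, top_n=top_n)\n",
    "    # Join the surviving nodes into a single context string\n",
    "    return \"\\n\\n\".join(n.text for n in reranked)\n",
    "\n",
    "\n",
    "vector_retriever = FakeRetriever([FakeNode(\"vector hit A\", 0.9), FakeNode(\"vector hit B\", 0.4)])\n",
    "cypher_retriever = FakeRetriever([FakeNode(\"cypher hit C\", 0.7)])\n",
    "\n",
    "# nest_asyncio.apply() earlier in this notebook lets asyncio.run work inside Jupyter\n",
    "result = asyncio.run(\n",
    "    acustom_retrieve(\"Did the author like programming?\", vector_retriever, cypher_retriever)\n",
    ")\n",
    "print(result)"
   ]
  },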
  {
   "cell_type": "markdown",
   "id": "e9a7ad0b-f170-40e8-b9bd-76deb5b42bdc",
   "metadata": {},
   "source": [
    "## Test out the Custom Retriever\n",
    "\n",
    "Now let's initialize and test out the custom retriever against our data! \n",
    "\n",
    "To build a full RAG pipeline, we use the `RetrieverQueryEngine` to combine our retriever with the LLM synthesis module. The property graph index uses the same engine under the hood."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7346d827-9fd8-42bf-8bc9-548f23da68fe",
   "metadata": {},
   "outputs": [],
   "source": [
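    "custom_sub_retriever = MyCustomRetriever(\n",
    "    index.property_graph_store,\n",
    "    include_text=True,\n",
    "    vector_store=index.vector_store,\n",
    "    # reuse the models from above so query embeddings match the index embeddings\n",
    "    embed_model=embed_model,\n",
    "    llm=llm,\n",
    "    cohere_api_key=\"...\",\n",
    ")"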
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1af69b34-184a-457a-9e98-d32b103ac0bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.core.query_engine import RetrieverQueryEngine\n",
    "\n",
    "query_engine = RetrieverQueryEngine.from_args(\n",
    "    index.as_retriever(sub_retrievers=[custom_sub_retriever]), llm=llm\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "79c9e150-b02f-4819-becf-b68375558fce",
   "metadata": {},
   "source": [
    "#### Try out a Baseline\n",
    "\n",
    "We compare against a baseline that uses only the `VectorContextRetriever`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "963306cc-3fb2-4a7d-9f1f-c8e6f66747c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "base_retriever = VectorContextRetriever(\n",
    "    index.property_graph_store, include_text=True, embed_model=embed_model\n",
    ")\n",
    "base_query_engine = index.as_query_engine(\n",
    "    sub_retrievers=[base_retriever], llm=llm\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3f7e5bda-b11d-4544-adb9-ebed298fe070",
   "metadata": {},
   "source": [
    "#### Try out some Queries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a950e8b7-91db-4153-8528-47a8b91d77f6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The author found working on programming challenging but satisfying, as indicated by the intense effort put into the project and the sense of accomplishment derived from solving complex problems while working on the code.\n"
     ]
    }
   ],
   "source": [
    "response = query_engine.query(\"Did the author like programming?\")\n",
    "print(str(response))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce50271c-1d42-434f-a541-909d731bff54",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The author enjoyed programming, as evidenced by their early experiences with computers, such as writing simple games, creating programs for predicting rocket flights, and developing a word processor. These experiences indicate a genuine interest and enjoyment in programming activities.\n"
     ]
    }
   ],
   "source": [
    "response = base_query_engine.query(\"Did the author like programming?\")\n",
    "print(str(response))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llama-index-bXUwlEfH-py3.11",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
